{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overwriting SGD_Training.py\n"
     ]
    }
   ],
   "source": [
    "%%writefile SGD_Training.py\n",
    "\n",
    "from __future__ import print_function, division\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.autograd import Variable\n",
    "import numpy as np\n",
    "import torchvision\n",
    "from torchvision import datasets, models, transforms\n",
    "import matplotlib.pyplot as plt\n",
    "import torch.optim as optim\n",
    "import time\n",
    "import copy\n",
    "import os\n",
    "import pdb\n",
    "import shutil\n",
    "from torch.utils.data import DataLoader\n",
    "#end of imports\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Appending to SGD_Training.py\n"
     ]
    }
   ],
   "source": [
    "%%writefile -a SGD_Training.py\n",
    "\n",
    "\n",
    "def train_model(model, criterion, optimizer, lr_scheduler,lr,dset_loaders,dset_sizes,use_gpu, num_epochs,exp_dir='./',resume=''):\n",
    "    print('dictoinary length'+str(len(dset_loaders)))\n",
    "    #reg_params=model.reg_params\n",
    "    since = time.time()\n",
    "\n",
    "    best_model = model\n",
    "    best_acc = 0.0\n",
    "    if os.path.isfile(resume):\n",
    "        print(\"=> loading checkpoint '{}'\".format(resume))\n",
    "        checkpoint = torch.load(resume)\n",
    "        start_epoch = checkpoint['epoch']\n",
    "        #best_prec1 = checkpoint['best_prec1']\n",
    "        #model = checkpoint['model']\n",
    "        model.load_state_dict(checkpoint['state_dict'])\n",
    "        #modelx = checkpoint['model']\n",
    "        #model.reg_params=modelx.reg_params\n",
    "        print('load')\n",
    "        optimizer.load_state_dict(checkpoint['optimizer'])\n",
    "        #pdb.\n",
    "        #model.reg_params=reg_params\n",
    "        #del model.reg_params\n",
    "        print(\"=> loaded checkpoint '{}' (epoch {})\"\n",
    "              .format(resume, checkpoint['epoch']))\n",
    "    else:\n",
    "            start_epoch=0\n",
    "            print(\"=> no checkpoint found at '{}'\".format(resume))\n",
    "    \n",
    "    print(str(start_epoch))\n",
    "    #pdb.set_trace()\n",
    "    for epoch in range(start_epoch, num_epochs):\n",
    "        print('Epoch {}/{}'.format(epoch, num_epochs - 1))\n",
    "        print('-' * 10)\n",
    "\n",
    "        \n",
    "        # Each epoch has a training and validation phase\n",
    "        for phase in ['train', 'val']:\n",
    "            if phase == 'train':\n",
    "                optimizer = lr_scheduler(optimizer, epoch,lr)\n",
    "                model.train(True)  # Set model to training mode\n",
    "            else:\n",
    "                model.train(False)  # Set model to evaluate mode\n",
    "\n",
    "            running_loss = 0.0\n",
    "            running_corrects = 0\n",
    "\n",
    "            # Iterate over data.\n",
    "            for data in dset_loaders[phase]:\n",
    "                # get the inputs\n",
    "                inputs, labels = data\n",
    "                inputs=inputs.squeeze()\n",
    "                # wrap them in Variable\n",
    "                if use_gpu:\n",
    "                    inputs, labels = Variable(inputs.cuda()), \\\n",
    "                        Variable(labels.cuda())\n",
    "                else:\n",
    "                    inputs, labels = Variable(inputs), Variable(labels)\n",
    "\n",
    "                # zero the parameter gradients\n",
    "                optimizer.zero_grad()\n",
    "                model.zero_grad()\n",
    "                # forward\n",
    "                outputs = model(inputs)\n",
    "                _, preds = torch.max(outputs.data, 1)\n",
    "                loss = criterion(outputs, labels)\n",
    "\n",
    "                # backward + optimize only if in training phase\n",
    "                if phase == 'train':\n",
    "                    loss.backward()\n",
    "                    #print('step')\n",
    "                    optimizer.step()\n",
    "\n",
    "                # statistics\n",
    "                running_loss += loss.data[0]\n",
    "                running_corrects += torch.sum(preds == labels.data)\n",
    "\n",
    "            epoch_loss = running_loss / dset_sizes[phase]\n",
    "            epoch_acc = running_corrects / dset_sizes[phase]\n",
    "\n",
    "            print('{} Loss: {:.4f} Acc: {:.4f}'.format(\n",
    "                phase, epoch_loss, epoch_acc))\n",
    "\n",
    "            # deep copy the model\n",
    "            if phase == 'val' and epoch_acc > best_acc:\n",
    "                del outputs\n",
    "                del labels\n",
    "                del inputs\n",
    "                del loss\n",
    "                del preds\n",
    "                best_acc = epoch_acc\n",
    "                #best_model = copy.deepcopy(model)\n",
    "                torch.save(model,os.path.join(exp_dir, 'best_model.pth.tar'))\n",
    "                \n",
    "        #epoch_file_name=exp_dir+'/'+'epoch-'+str(epoch)+'.pth.tar'\n",
    "        epoch_file_name=exp_dir+'/'+'epoch'+'.pth.tar'\n",
    "        save_checkpoint({\n",
    "            'epoch': epoch + 1,\n",
    "            'epoch_acc':epoch_acc,\n",
    "            'arch': 'alexnet',\n",
    "            'model': model,\n",
    "            'state_dict': model.state_dict(),\n",
    "            'optimizer' : optimizer.state_dict(),\n",
    "                },epoch_file_name)\n",
    "        print()\n",
    "\n",
    "    time_elapsed = time.time() - since\n",
    "    print('Training complete in {:.0f}m {:.0f}s'.format(\n",
    "        time_elapsed // 60, time_elapsed % 60))\n",
    "    print('Best val Acc: {:4f}'.format(best_acc))\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Appending to SGD_Training.py\n"
     ]
    }
   ],
   "source": [
    "%%writefile -a SGD_Training.py\n",
    "\n",
    "def save_checkpoint(state, filename='checkpoint.pth.tar'):\n",
    "    #best_model = copy.deepcopy(model)\n",
    "    torch.save(state, filename)\n",
    "   "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
